{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e14d3e0a-6837-48ad-b971-61b72f9ea397",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取数据集，这里是直接联网读取，也可以通过下载文件，再读取\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n",
    "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n",
    "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "abcb3995-1cbe-44f1-8eeb-2f921a7e8b4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "train_data.columns = [\"document\", \"summary\"]\n",
    "train_ds = Dataset.from_pandas(train_data.iloc[:-2000])\n",
    "eval_ds = Dataset.from_pandas(train_data.iloc[-2000:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7c252708-778a-427b-a3b1-60a99c77aadb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
     ]
    }
   ],
   "source": [
    "from transformers import T5Tokenizer\n",
    "\n",
    "pretrained_model = \"/home/lyz/hf-models/IDEA-CCNL/Randeng-T5-784M-MultiTask-Chinese/\"\n",
    "\n",
    "special_tokens = [\"<extra_id_{}>\".format(i) for i in range(100)]\n",
    "tokenizer = T5Tokenizer.from_pretrained(\n",
    "    pretrained_model,\n",
    "    do_lower_case=True,\n",
    "    max_length=512,\n",
    "    truncation=True,\n",
    "    additional_special_tokens=special_tokens,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "25df70a4-dba8-4745-a651-14cfde26838d",
   "metadata": {},
   "outputs": [],
   "source": [
    "head_prefix = \"意图识别任务: \"\n",
    "tail_prefix = \" Travel-Query/Music-Play/FilmTele-Play/Video-Play/Radio-Listen/HomeAppliance-Control/Weather-Query/Alarm-Update/Calendar-Query/TVProgram-Play/Audio-Play/Other\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "be466865-ebf7-4491-a093-fd4b73ceb8e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_input_length = 60\n",
    "max_target_length = 20\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    model_inputs = tokenizer(head_prefix + examples[\"document\"], max_length=max_input_length, truncation=True)\n",
    "\n",
    "    # Setup the tokenizer for targets\n",
    "    labels = tokenizer(text_target=examples[\"summary\"], max_length=max_target_length, truncation=True)\n",
    "\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a4ba7d2a-d2e2-443f-b2be-c67e0bab257a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "60d8a39414c447bfb8351046c64c0074",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/10100 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62339fe42d5849439dd08329fc7781af",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/2000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_tokenized_id = train_ds.map(preprocess_function, remove_columns=train_ds.column_names)\n",
    "eval_tokenized_id = eval_ds.map(preprocess_function, remove_columns=train_ds.column_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6aa17b2a-ea47-4af7-b55c-ddb9a43bb01f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model,  device_map=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "67c28375-742c-4cd9-b9b9-524e2116c0cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
    "\n",
    "batch_size = 4\n",
    "args = Seq2SeqTrainingArguments(\n",
    "    \"t5-finetuned\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    weight_decay=0.01,\n",
    "    save_total_limit=3,\n",
    "    gradient_accumulation_steps=10,\n",
    "    do_eval=True,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=50,\n",
    "    num_train_epochs=5,\n",
    "    save_steps=50,\n",
    "    save_on_each_node=True,\n",
    "    gradient_checkpointing=True,\n",
    "    load_best_model_at_end=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6a63025b-5f49-42f2-b829-e7c8f8865505",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Seq2SeqTrainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=train_tokenized_id,\n",
    "    eval_dataset=eval_tokenized_id,\n",
    "    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "145bce5e-ecd2-4d79-8186-ca93ad96b139",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='151' max='1260' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 151/1260 12:17 < 1:31:32, 0.20 it/s, Epoch 0.59/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.695113</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.314147</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.245422</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
      "Non-default generation parameters: {'max_length': 200}\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
      "Non-default generation parameters: {'max_length': 200}\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) instead. This warning will be raised to an exception in v4.41.\n",
      "Non-default generation parameters: {'max_length': 200}\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/py311/lib/python3.11/site-packages/transformers/trainer.py:1859\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1857\u001b[0m         hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m   1858\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1859\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1860\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1861\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1862\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1863\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1864\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/py311/lib/python3.11/site-packages/transformers/trainer.py:2203\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2200\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m   2202\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2203\u001b[0m     tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m   2206\u001b[0m     args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m   2207\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m   2208\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m   2209\u001b[0m ):\n\u001b[1;32m   2210\u001b[0m     \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m   2211\u001b[0m     tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
      "File \u001b[0;32m~/anaconda3/envs/py311/lib/python3.11/site-packages/transformers/trainer.py:3147\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m   3145\u001b[0m         scaled_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m   3146\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3147\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mgradient_accumulation_steps\n",
      "File \u001b[0;32m~/anaconda3/envs/py311/lib/python3.11/site-packages/accelerate/accelerator.py:2124\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m   2122\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n\u001b[1;32m   2123\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2124\u001b[0m     \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/py311/lib/python3.11/site-packages/torch/_tensor.py:522\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    512\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    513\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    514\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    515\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    520\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m    521\u001b[0m     )\n\u001b[0;32m--> 522\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    523\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m    524\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/py311/lib/python3.11/site-packages/torch/autograd/__init__.py:266\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    261\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    263\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m    264\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    265\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 266\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    267\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    268\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    269\u001b[0m \u001b[43m    \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    270\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    271\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    272\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    273\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    274\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "dbac61a7-73a0-491a-82d6-66be377cb98f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/torch/utils/checkpoint.py:90: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Travel-Query']\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# tokenize\n",
    "text = \"意图识别任务: 还有双鸭山到淮阴的汽车票吗13号的\"\n",
    "encode_dict = tokenizer(text, max_length=512, padding='max_length',truncation=True)\n",
    "\n",
    "inputs = {\n",
    "  \"input_ids\": torch.tensor([encode_dict['input_ids']]).long().to(device),\n",
    "  \"attention_mask\": torch.tensor([encode_dict['attention_mask']]).long().to(device),\n",
    "  }\n",
    "\n",
    "# generate answer\n",
    "logits = model.generate(\n",
    "  input_ids = inputs['input_ids'],\n",
    "  max_length=100, \n",
    "  do_sample= True\n",
    "  # early_stopping=True,\n",
    "  )\n",
    "\n",
    "logits=logits[:,1:]\n",
    "predict_label = [tokenizer.decode(i,skip_special_tokens=True) for i in logits]\n",
    "print(predict_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "abd84409-f28e-45ca-87cd-9317f4ad2376",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_39398/516267186.py:6: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  for train_text in tqdm_notebook(test_data[0].values):\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ea36174c189407ab455121594585ae2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import tqdm_notebook\n",
    "\n",
    "model.eval()\n",
    "\n",
    "pred_label = []\n",
    "for train_text in tqdm_notebook(test_data[0].values):\n",
    "    text = f\"意图识别任务: {train_text}\"\n",
    "    encode_dict = tokenizer(text, max_length=512, padding='max_length', truncation=True)\n",
    "    \n",
    "    inputs = {\n",
    "      \"input_ids\": torch.tensor([encode_dict['input_ids']]).long().to(device),\n",
    "      \"attention_mask\": torch.tensor([encode_dict['attention_mask']]).long().to(device),\n",
    "    }\n",
    "    \n",
    "    # generate answer\n",
    "    logits = model.generate(\n",
    "        input_ids = inputs['input_ids'],\n",
    "        max_length=100, \n",
    "        do_sample= True\n",
    "    )\n",
    "    \n",
    "    logits = logits[:,:]\n",
    "    pred_label += [tokenizer.decode(i,skip_special_tokens=True) for i in logits]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "42416cbf-e174-48dd-bb4f-e189b3036bf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame({\n",
    "    'ID': range(1, len(pred_label) + 1),\n",
    "    'Target': pred_label,\n",
    "}).to_csv('nlp_submit.csv', index=None)\n",
    "\n",
    "# 提交一下吧~"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0867df5a-9c60-4782-a0b1-b96beb338aa6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Video-Play',\n",
       " 'HomeAppliance-Control',\n",
       " 'Video-Play',\n",
       " 'Travel-Query',\n",
       " 'HomeAppliance-Control',\n",
       " 'FilmTele-Play',\n",
       " 'FilmTele-Play',\n",
       " 'Music-Play',\n",
       " 'Calendar-Query',\n",
       " 'Video-Play']"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_label[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b8edec4-4710-464d-b258-89cdf611bbfe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
